import torch
import time
import model
import torch.optim
import torch.nn
from dataset import ParkingSlotDataset
from torch.utils.data import DataLoader
import myloss
import os

# Restrict visible GPUs before any CUDA context is created.
# Fix: the variable name was misspelled "CUDA_VISIBLE_DEVICS", so the
# setting was silently ignored and all GPUs stayed visible.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# Avoid "too many open files" errors with many DataLoader workers.
torch.multiprocessing.set_sharing_strategy('file_system')

# dataset_path =  study\\code\\mycode_3-16\\weights\\"
# Checkpoint file name: epoch 23, pos loss 0.0000768, angle loss 0.0343446.
weight_name = "weight23_0.0000768_0.0343446_end.pth"

# Linux paths (current machine).
weight_path = "/home/sang/公共的/skq/mycode_3-16/weights/"
dataset_path = "/skq/data-input/output/val/"
# Windows paths (alternative machine) — swap in when running there.
# weight_path = "F:\\study\\code\\mycode_3-16\\weights\\"
# dataset_path = "G:/0A/output/data-input/output/val/"

# dataset_path = "F:\\study\mydataset"

# Training-side defaults; val() doubles the batch size (no gradients kept).
batch_size = 80
num_workers = 20


def val(tmodel=None, batch_size=batch_size*2, num_workers=num_workers,
        weight_path=weight_path+weight_name, val_path=dataset_path,
        device=torch.device('cuda:0')):
    """Run one evaluation pass over the validation set and report losses.

    Args:
        tmodel: an already-built network to evaluate. When ``None``, a fresh
            ``model.PSNet`` is constructed, wrapped in ``DataParallel`` and
            loaded from ``weight_path``.
        batch_size: validation batch size (defaults to twice the training
            batch size, since no gradients are stored).
        num_workers: number of ``DataLoader`` worker processes.
        weight_path: checkpoint file used only when ``tmodel`` is ``None``.
        val_path: root directory of the validation dataset.
        device: device the inputs and targets are moved to.

    Returns:
        Tuple ``(aver_loss_pos, aver_loss_angle)``: position loss averaged
        per batch, and angle loss averaged per marking point.
    """
    if tmodel is None:
        # Standalone run: no gradients needed anywhere.
        torch.set_grad_enabled(False)
        net = model.PSNet()
        net = torch.nn.DataParallel(net, device_ids=[0])
        net = net.to(device)
        # strict=False tolerates key mismatches between checkpoint and model.
        net.module.load_state_dict(torch.load(weight_path, map_location="cuda:0"), strict=False)
    else:
        net = tmodel
    net.eval()
    data = ParkingSlotDataset(val_path)
    # collate_fn transposes the batch: list of samples -> per-field lists.
    data_loader = DataLoader(data,
                             batch_size=batch_size, shuffle=True,
                             num_workers=num_workers,
                             collate_fn=lambda x: list(zip(*x)))
    Loss = myloss.C_Loss()
    all_loss_pos = 0
    all_loss_angle = 0
    aver_loss_pos = 0
    aver_loss_angle = 0
    num_points = 0
    # Fix: was `start = 0`, which made the first iteration's iter_time (and
    # the ETA derived from it) an epoch-timestamp-sized nonsense value.
    start = time.time()
    for iter_idx, (images, target_pos, target_angle, mark_points) in enumerate(data_loader):
        with torch.no_grad():

            u_start = time.time()
            images = torch.stack(images).to(device)
            target_pos = torch.stack(target_pos)
            target_pos = target_pos.to(device)
            target_angle = torch.stack(target_angle).to(device)
            u_end = time.time()
            up_time = u_end-u_start        # host->device upload time
            output_pos, output_angle = net(images)
            n_end = time.time()
            net_time = n_end-u_end         # forward-pass time
            loss_pos, loss_angle, gradient = Loss(output_pos, output_angle, target_pos, target_angle, mark_points, device)
            l_end = time.time()
            l_time = l_end - n_end         # loss-computation time
            # Count ground-truth marking points in this batch (idiomatic sum
            # replaces the manual per-image index loop).
            num_points += sum(len(pts) for pts in mark_points)
            all_loss_pos += loss_pos.item()
            # Angle loss is masked by `gradient` so only valid points count.
            loss_angle_item = torch.sum(torch.mul(loss_angle, gradient)).item()
            all_loss_angle += loss_angle_item
            aver_loss_pos = all_loss_pos / (iter_idx + 1)
            # Guard against division by zero when no points were seen yet.
            aver_loss_angle = all_loss_angle / num_points if num_points else 0.0
            end = time.time()
            a_time = end - l_end
            iter_time = end-start
            # Remaining time estimate in hours.
            af_time = (len(data_loader)-iter_idx)*iter_time/3600
            start = time.time()
            print("aver_loss_pos:{:.7f}  aver_loss_angle:{:.7f} iter:{}/{}    time:{:.4f} up_time:{:.4f} net_time:{:.4f} l_time:{:.4f} a_time:{:.4f} 剩余：{:.4f}"
                  .format(aver_loss_pos, aver_loss_angle, iter_idx, len(data_loader), iter_time, up_time, net_time, l_time, a_time, af_time))
    print("aver_loss_pos:{:.7f}  aver_loss_angle:{:.7f} "
          .format(aver_loss_pos, aver_loss_angle))
    return aver_loss_pos, aver_loss_angle

# Script entry point: run a standalone validation pass with module defaults.
if __name__ == '__main__':
    val()